{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sentence Classification with Convolutional Neural Networks\n",
    "[Paper](https://arxiv.org/pdf/1408.5882.pdf): Convolutional Neural Networks for Sentence Classification by Yoon Kim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\thushan\\documents\\python_virtualenvs\\tensorflow_venv\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "# These are all the modules we'll be using later. Make sure you can import them\n",
    "# before proceeding further.\n",
    "%matplotlib inline\n",
    "from __future__ import print_function\n",
    "import collections\n",
    "import math\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import zipfile\n",
    "from matplotlib import pylab\n",
    "from six.moves import range\n",
    "from six.moves.urllib.request import urlretrieve\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading and Checking the Dataset\n",
    "This [dataset](http://cogcomp.cs.illinois.edu/Data/QA/QC/) is composed of questions as inputs and their respective types as outputs. For example, the input could be the question *Who was Abraham Lincoln?* and the output (label) would be Human."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "question-classif-data\\train_1000.label\n",
      "Found and verified question-classif-data\\train_1000.label\n",
      "question-classif-data\\TREC_10.label\n",
      "Found and verified question-classif-data\\TREC_10.label\n"
     ]
    }
   ],
   "source": [
    "# Base URL the dataset files are fetched from (a file name is appended to it)\n",
    "url = 'http://cogcomp.org/Data/QA/QC/'\n",
    "# Local directory the dataset files are stored in\n",
    "dir_name = 'question-classif-data'\n",
    "\n",
    "def maybe_download(dir_name, filename, expected_bytes):\n",
    "  \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n",
    "  if not os.path.exists(dir_name):\n",
    "        os.mkdir(dir_name)\n",
    "  if not os.path.exists(os.path.join(dir_name,filename)):\n",
    "    filename, _ = urlretrieve(url + filename, os.path.join(dir_name,filename))\n",
    "  print(os.path.join(dir_name,filename))\n",
    "  statinfo = os.stat(os.path.join(dir_name,filename))\n",
    "  if statinfo.st_size == expected_bytes:\n",
    "    print('Found and verified %s' % os.path.join(dir_name,filename))\n",
    "  else:\n",
    "    print(statinfo.st_size)\n",
    "    raise Exception(\n",
    "      'Failed to verify ' + os.path.join(dir_name,filename) + '. Can you get to it with a browser?')\n",
    "  return filename\n",
    "\n",
    "# Fetch (if not already present) and size-check the training and test files\n",
    "filename = maybe_download(dir_name, 'train_1000.label', 60774)\n",
    "test_filename = maybe_download(dir_name, 'TREC_10.label',23354)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files found and verified.\n"
     ]
    }
   ],
   "source": [
    "# Names of the dataset files that must be present in `dir_name`\n",
    "filenames = ['train_1000.label','TREC_10.label']\n",
    "num_files = len(filenames)\n",
    "\n",
    "# Stop right here if any of the files is missing\n",
    "for name in filenames:\n",
    "    assert os.path.isfile(os.path.join(dir_name, name))\n",
    "print('Files found and verified.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading and Preprocessing Data\n",
    "Below we load the text into the program and do some simple preprocessing on data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Processing file question-classif-data\\train_1000.label\n",
      "\tQuestion 0: ['manner', 'how', 'did', 'serfdom', 'develop', 'in', 'and', 'then', 'leave', 'russia', '?']\n",
      "\tLabel 0: DESC\n",
      "\n",
      "\tQuestion 1: ['cremat', 'what', 'films', 'featured', 'the', 'character', 'popeye', 'doyle', '?']\n",
      "\tLabel 1: ENTY\n",
      "\n",
      "\tQuestion 2: ['manner', 'how', 'can', 'i', 'find', 'a', 'list', 'of', 'celebrities', \"'\", 'real', 'names', '?']\n",
      "\tLabel 2: DESC\n",
      "\n",
      "\tQuestion 3: ['animal', 'what', 'fowl', 'grabs', 'the', 'spotlight', 'after', 'the', 'chinese', 'year', 'of', 'the', 'monkey', '?']\n",
      "\tLabel 3: ENTY\n",
      "\n",
      "\tQuestion 4: ['exp', 'what', 'is', 'the', 'full', 'form', 'of', '.com', '?']\n",
      "\tLabel 4: ABBR\n",
      "\n",
      "\n",
      "Processing file question-classif-data\\TREC_10.label\n",
      "\tQuestion 0: ['manner', 'how', 'did', 'serfdom', 'develop', 'in', 'and', 'then', 'leave', 'russia', '?']\n",
      "\tLabel 0: DESC\n",
      "\n",
      "\tQuestion 1: ['cremat', 'what', 'films', 'featured', 'the', 'character', 'popeye', 'doyle', '?']\n",
      "\tLabel 1: ENTY\n",
      "\n",
      "\tQuestion 2: ['manner', 'how', 'can', 'i', 'find', 'a', 'list', 'of', 'celebrities', \"'\", 'real', 'names', '?']\n",
      "\tLabel 2: DESC\n",
      "\n",
      "\tQuestion 3: ['animal', 'what', 'fowl', 'grabs', 'the', 'spotlight', 'after', 'the', 'chinese', 'year', 'of', 'the', 'monkey', '?']\n",
      "\tLabel 3: ENTY\n",
      "\n",
      "\tQuestion 4: ['exp', 'what', 'is', 'the', 'full', 'form', 'of', '.com', '?']\n",
      "\tLabel 4: ABBR\n",
      "\n",
      "Max Sentence Length: 33\n",
      "\n",
      "Normalizing all sentences to same length\n"
     ]
    }
   ],
   "source": [
    "# Records the maximum length of the sentences\n",
    "# as we need to pad shorter sentences accordingly.\n",
    "# read_data() raises this value (through `global`) whenever it reads a longer question.\n",
    "max_sent_length = 0 \n",
    "\n",
    "def read_data(filename):\n",
    "  '''\n",
    "  Read data from a file with given filename\n",
    "  Returns a list of strings where each string is a lower case word\n",
    "  '''\n",
    "  global max_sent_length\n",
    "  questions = []\n",
    "  labels = []\n",
    "  with open(filename,'r',encoding='latin-1') as f:        \n",
    "    for row in f:\n",
    "        row_str = row.split(\":\")\n",
    "        lb,q = row_str[0],row_str[1]\n",
    "        q = q.lower()\n",
    "        labels.append(lb)\n",
    "        questions.append(q.split())        \n",
    "        if len(questions[-1])>max_sent_length:\n",
    "            max_sent_length = len(questions[-1])\n",
    "  return questions,labels\n",
    "\n",
    "# Process train and test data\n",
    "for i in range(num_files):\n",
    "    file_path = os.path.join(dir_name,filenames[i])\n",
    "    print('\\nProcessing file %s'%file_path)\n",
    "    if i==0:\n",
    "        # Processing training data\n",
    "        train_questions,train_labels = read_data(file_path)\n",
    "        # Making sure we got all the questions and corresponding labels\n",
    "        assert len(train_questions)==len(train_labels)\n",
    "        sample_questions,sample_labels = train_questions,train_labels\n",
    "    elif i==1:\n",
    "        # Processing testing data\n",
    "        test_questions,test_labels = read_data(file_path)\n",
    "        # Making sure we got all the questions and corresponding labels.\n",
    "        assert len(test_questions)==len(test_labels)\n",
    "        sample_questions,sample_labels = test_questions,test_labels\n",
    "\n",
    "    # Print some data of the file that was just read to see everything is okay.\n",
    "    # The samples must come from that file: printing train_questions here\n",
    "    # would show the training samples under the test file as well.\n",
    "    for j in range(min(5,len(sample_questions))):\n",
    "        print('\\tQuestion %d: %s' %(j,sample_questions[j]))\n",
    "        print('\\tLabel %d: %s\\n'%(j,sample_labels[j]))\n",
    "\n",
    "print('Max Sentence Length: %d'%max_sent_length)\n",
    "print('\\nNormalizing all sentences to same length')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Padding Shorter Sentences\n",
    "We use padding to pad short sentences so that all the sentences are of the same length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train questions padded\n",
      "\n",
      "Test questions padded\n",
      "\n",
      "Sample test question: %s ['dist', 'how', 'far', 'is', 'it', 'from', 'denver', 'to', 'aspen', '?', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']\n"
     ]
    }
   ],
   "source": [
    "def pad_questions(questions, length, pad_token='PAD'):\n",
    "    '''\n",
    "    Append pad_token to every question (in place) until it holds `length` tokens.\n",
    "    Fails with an AssertionError if a question is longer than `length`.\n",
    "    '''\n",
    "    for que in questions:\n",
    "        que.extend([pad_token]*(length-len(que)))\n",
    "        assert len(que)==length\n",
    "\n",
    "# Padding training data\n",
    "pad_questions(train_questions, max_sent_length)\n",
    "print('Train questions padded')\n",
    "\n",
    "# Padding testing data\n",
    "pad_questions(test_questions, max_sent_length)\n",
    "print('\\nTest questions padded')\n",
    "\n",
    "# Printing a test question to see if everything is correct.\n",
    "# The question is interpolated with %; handing it to print() as a second\n",
    "# argument would leave a literal '%s' in the output.\n",
    "print('\\nSample test question: %s'%test_questions[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the Dictionaries\n",
    "Builds the following. To understand each of these elements, let us also assume the text \"I like to go to school\"\n",
    "\n",
    "* `dictionary`: maps a string word to an ID (e.g. {I:0, like:1, to:2, go:3, school:4})\n",
    "* `reverse_dictionary`: maps an ID to a string word (e.g. {0:I, 1:like, 2:to, 3:go, 4:school})\n",
    "* `count`: List of (word, frequency) tuples (e.g. [(I,1),(like,1),(to,2),(go,1),(school,1)])\n",
    "* `data` : Contain the string of text we read, where string words are replaced with word IDs (e.g. [0, 1, 2, 3, 2, 4])\n",
    "\n",
    "We do not replace rare words with \"UNK\" because the vocabulary is already quite small."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "49500 Words found.\n",
      "Found 3369 words in the vocabulary. \n",
      "All words (count) [('PAD', 34407), ('?', 1454), ('the', 999), ('what', 963), ('is', 587)]\n",
      "\n",
      "0th entry in dictionary: %s PAD\n",
      "\n",
      "Sample data [38, 12, 19, 1977, 1118, 6, 28, 2230, 3107, 686, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Sample data [44, 3, 881, 2852, 2, 173, 2113, 2996, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Vocabulary:  3369\n",
      "\n",
      "Number of training questions:  1000\n",
      "Number of testing questions:  500\n"
     ]
    }
   ],
   "source": [
    "def build_dataset(questions):\n",
    "    words = []\n",
    "    data_list = []\n",
    "    count = []\n",
    "    \n",
    "    # First create a large list with all the words in all the questions\n",
    "    for d in questions:\n",
    "        words.extend(d)\n",
    "    print('%d Words found.'%len(words))    \n",
    "    print('Found %d words in the vocabulary. '%len(collections.Counter(words).most_common()))\n",
    "    \n",
    "    # Sort words by there frequency\n",
    "    count.extend(collections.Counter(words).most_common())\n",
    "    \n",
    "    # Create an ID for each word by giving the current length of the dictionary\n",
    "    # And adding that item to the dictionary\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    \n",
    "    # Traverse through all the text and \n",
    "    # replace the string words with the ID \n",
    "    # of the word found at that index\n",
    "    for d in questions:\n",
    "        data = list()\n",
    "        for word in d:\n",
    "            index = dictionary[word]        \n",
    "            data.append(index)\n",
    "            \n",
    "        data_list.append(data)\n",
    "        \n",
    "    reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) \n",
    "    \n",
    "    return data_list, count, dictionary, reverse_dictionary\n",
    "\n",
    "# Create a dataset with both train and test questions\n",
    "all_questions = list(train_questions)\n",
    "all_questions.extend(test_questions)\n",
    "\n",
    "# Use the above created dataset to build the vocabulary\n",
    "all_question_ind, count, dictionary, reverse_dictionary = build_dataset(all_questions)\n",
    "vocabulary_size = len(dictionary)\n",
    "\n",
    "# Print some statistics about the processed data\n",
    "print('All words (count)', count[:5])\n",
    "# Interpolate with %; a second print() argument would leave a literal '%s' in the output\n",
    "print('\\n0th entry in dictionary: %s'%reverse_dictionary[0])\n",
    "print('\\nSample data', all_question_ind[0])\n",
    "print('\\nSample data', all_question_ind[1])\n",
    "print('\\nVocabulary: ',vocabulary_size)\n",
    "\n",
    "print('\\nNumber of training questions: ',len(train_questions))\n",
    "print('Number of testing questions: ',len(test_questions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating Batches of Data\n",
    "Below I show the code to generate a batch of data from a given set of questions and labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample batch labels\n",
      "[3 4 3 4 5 2 2 2 3 2 0 3 2 2 4 1]\n",
      "[3 0 3 3 0 4 2 3 3 4 2 1 4 1 5 4]\n"
     ]
    }
   ],
   "source": [
    "# All the types of question that are in the dataset;\n",
    "# the position of a label in this list is used as its class index\n",
    "all_labels = ['NUM','LOC','HUM','DESC','ENTY','ABBR']\n",
    "num_classes = len(all_labels) # Number of classes (6)\n",
    "\n",
    "batch_size = 16 # We process 16 questions at a time\n",
    "sent_length = max_sent_length # Number of tokens of every padded question\n",
    "\n",
    "class BatchGenerator(object):\n",
    "    '''\n",
    "    Generates a batch of data\n",
    "    '''\n",
    "    def __init__(self,batch_size,questions,labels):\n",
    "        self.questions = questions\n",
    "        self.labels = labels\n",
    "        self.text_size = len(questions)\n",
    "        self.batch_size = batch_size\n",
    "        self.data_index = 0\n",
    "        assert len(self.questions)==len(self.labels)\n",
    "        \n",
    "    def generate_batch(self):\n",
    "        '''\n",
    "        Data generation function. This outputs two matrices\n",
    "        inputs: a batch of questions where each question is a tensor of size\n",
    "        [sent_length, vocabulary_size] with each word one-hot-encoded\n",
    "        labels_ohe: one-hot-encoded labels corresponding to the questions in inputs\n",
    "        '''\n",
    "        global sent_length,num_classes\n",
    "        global dictionary, all_labels\n",
    "        \n",
    "        # Numpy arrays holding input and label data\n",
    "        inputs = np.zeros((self.batch_size,sent_length,vocabulary_size),dtype=np.float32)\n",
    "        labels_ohe = np.zeros((self.batch_size,num_classes),dtype=np.float32)\n",
    "        \n",
    "        # When we reach the end of the dataset\n",
    "        # start from beginning\n",
    "        if self.data_index + self.batch_size >= self.text_size:\n",
    "            self.data_index = 0\n",
    "            \n",
    "        # For each question in the dataset\n",
    "        for qi,que in enumerate(self.questions[self.data_index:self.data_index+self.batch_size]):\n",
    "            # For each word in the question\n",
    "            for wi,word in enumerate(que): \n",
    "                # Set the element at the word ID index to 1\n",
    "                # this gives the one-hot-encoded vector of that word\n",
    "                inputs[qi,wi,dictionary[word]] = 1.0\n",
    "            \n",
    "            # Set the index corrsponding to that particular class to 1\n",
    "            labels_ohe[qi,all_labels.index(self.labels[self.data_index + qi])] = 1.0\n",
    "        \n",
    "        # Update the data index to get the next batch of data\n",
    "        self.data_index = (self.data_index + self.batch_size)%self.text_size\n",
    "            \n",
    "        return inputs,labels_ohe\n",
    "    \n",
    "    def return_index(self):\n",
    "        # Get the current index of data\n",
    "        return self.data_index\n",
    "\n",
    "# Try out the batch generator on the training data by drawing two consecutive batches\n",
    "sample_gen = BatchGenerator(batch_size,train_questions,train_labels)\n",
    "sample_batch_inputs,sample_batch_labels = sample_gen.generate_batch()\n",
    "sample_batch_inputs_2,sample_batch_labels_2 = sample_gen.generate_batch()\n",
    "\n",
    "# The first row of the first batch must decode back to the word IDs of question 0\n",
    "expected_ids = np.asarray([dictionary[w] for w in train_questions[0]],dtype=np.int32)\n",
    "decoded_ids = np.argmax(sample_batch_inputs[0],axis=1)\n",
    "assert np.all(expected_ids == decoded_ids)\n",
    "\n",
    "# Show the class indices of the labels of both batches\n",
    "print('Sample batch labels')\n",
    "for batch_labels in (sample_batch_labels,sample_batch_labels_2):\n",
    "    print(np.argmax(batch_labels,axis=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sentence Classifying Convolution Neural Network\n",
    "We are going to implement a very simple CNN to classify sentences. However, you will see that even with this simple structure we achieve good accuracies. Our CNN will have one convolution layer (made of 3 parallel convolutions with different filter widths). This will be followed by a pooling-over-time layer and finally a fully connected layer that produces the logits."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining hyperparameters and inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Reset the default graph before the model ops are (re)defined in the cells below\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# NOTE: this replaces the batch_size of 16 set in the batch generation cell\n",
    "batch_size = 32\n",
    "# Different filter sizes we use in a single convolution layer\n",
    "filter_sizes = [3,5,7] \n",
    "\n",
    "# inputs and labels: one-hot-encoded questions and their one-hot-encoded classes.\n",
    "# The batch dimension is fixed, so every feed has to hold exactly batch_size questions\n",
    "sent_inputs = tf.placeholder(shape=[batch_size,sent_length,vocabulary_size],dtype=tf.float32,name='sentence_inputs')\n",
    "sent_labels = tf.placeholder(shape=[batch_size,num_classes],dtype=tf.float32,name='sentence_labels')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Model Parameters\n",
    "Our model has following parameters.\n",
    "* 3 sets of convolution layer weights and biases (one for each parallel layer)\n",
    "* 1 fully connected output layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 3 filters with different context window sizes (3,5,7)\n",
    "# Each of these filters spans the full one-hot-encoded length of each word and the context window width.\n",
    "# The weight shape is [filter width, vocabulary_size (input channels), 1 (output channel)]\n",
    "\n",
    "# Weights of the first parallel layer\n",
    "w1 = tf.Variable(tf.truncated_normal([filter_sizes[0],vocabulary_size,1],stddev=0.02,dtype=tf.float32),name='weights_1')\n",
    "b1 = tf.Variable(tf.random_uniform([1],0,0.01,dtype=tf.float32),name='bias_1')\n",
    "\n",
    "# Weights of the second parallel layer\n",
    "w2 = tf.Variable(tf.truncated_normal([filter_sizes[1],vocabulary_size,1],stddev=0.02,dtype=tf.float32),name='weights_2')\n",
    "b2 = tf.Variable(tf.random_uniform([1],0,0.01,dtype=tf.float32),name='bias_2')\n",
    "\n",
    "# Weights of the third parallel layer\n",
    "w3 = tf.Variable(tf.truncated_normal([filter_sizes[2],vocabulary_size,1],stddev=0.02,dtype=tf.float32),name='weights_3')\n",
    "b3 = tf.Variable(tf.random_uniform([1],0,0.01,dtype=tf.float32),name='bias_3')\n",
    "\n",
    "# Fully connected layer\n",
    "# It takes len(filter_sizes) inputs (one value per filter) and outputs one logit per class\n",
    "w_fc1 = tf.Variable(tf.truncated_normal([len(filter_sizes),num_classes],stddev=0.5,dtype=tf.float32),name='weights_fulcon_1')\n",
    "b_fc1 = tf.Variable(tf.random_uniform([num_classes],0,0.01,dtype=tf.float32),name='bias_fulcon_1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Inference of the CNN\n",
    "Here we define the CNN inference logic. First compute the convolution output for each parallel layer within the convolution layer. Then perform pooling-over-time over all the convolution outputs. Finally feed the output of the pooling layer to a fully connected layer to obtain the output logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Calculate the output for all the filters with a stride 1\n",
    "# We use relu as the activation function\n",
    "# With 'SAME' padding and stride 1 each output is of size [batch_size, sent_length, 1]\n",
    "h1_1 = tf.nn.relu(tf.nn.conv1d(sent_inputs,w1,stride=1,padding='SAME') + b1)\n",
    "h1_2 = tf.nn.relu(tf.nn.conv1d(sent_inputs,w2,stride=1,padding='SAME') + b2)\n",
    "h1_3 = tf.nn.relu(tf.nn.conv1d(sent_inputs,w3,stride=1,padding='SAME') + b3)\n",
    "\n",
    "# Pooling over time operation\n",
    "\n",
    "# This is doing the max pooling. There are two options to do the max pooling\n",
    "# 1. Use tf.nn.max_pool operation on a tensor made by concatenating h1_1,h1_2,h1_3 and converting that tensor to 4D\n",
    "# (Because max_pool takes a 4D tensor)\n",
    "# 2. Do the max pooling separately for each filter output and combine them using tf.concat\n",
    "# (this is the one used in the code)\n",
    "\n",
    "# Maximum over the word positions (axis 1): [batch_size, 1] per filter\n",
    "h2_1 = tf.reduce_max(h1_1,axis=1)\n",
    "h2_2 = tf.reduce_max(h1_2,axis=1)\n",
    "h2_3 = tf.reduce_max(h1_3,axis=1)\n",
    "\n",
    "# One pooled value per filter side by side: [batch_size, 3]\n",
    "h2 = tf.concat([h2_1,h2_2,h2_3],axis=1)\n",
    "\n",
    "# Calculate the fully connected layer output (no activation)\n",
    "# Note: since h2 is 2d [batch_size,number of parallel filters]\n",
    "# reshaping the output is not required as it usually is in CNNs\n",
    "logits = tf.matmul(h2,w_fc1) + b_fc1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Loss and the Optimizer\n",
    "We compute the cross entropy loss and use the momentum optimizer (which works better than standard gradient descent) to optimize our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Loss (Cross-Entropy)\n",
    "# Mean over the batch of the softmax cross-entropy between the one-hot labels and the logits\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=sent_labels,logits=logits))\n",
    "\n",
    "# Momentum Optimizer\n",
    "# `optimizer` is the op that minimizes the loss; the training loop runs it once per batch\n",
    "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01,momentum=0.9).minimize(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Predictions\n",
    "Note that we are not getting the raw predictions, but the index of the maximally activated element in the prediction vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Index of the largest softmax output of every question in the batch, i.e. the predicted class index\n",
    "predictions = tf.argmax(tf.nn.softmax(logits),axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Our Model to Classify Sentences\n",
    "\n",
    "Below we run our algorithm for 50 epochs. With the provided hyperparameters you should achieve around 90% accuracy on the test set. However you are welcome to play around with the hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialized\n",
      "\n",
      "Train Loss at Epoch 0: 1.75\n",
      "Test accuracy at Epoch 0: 13.333\n",
      "Train Loss at Epoch 1: 1.69\n",
      "Test accuracy at Epoch 1: 13.333\n",
      "Train Loss at Epoch 2: 1.63\n",
      "Test accuracy at Epoch 2: 26.875\n",
      "Train Loss at Epoch 3: 1.58\n",
      "Test accuracy at Epoch 3: 28.542\n",
      "Train Loss at Epoch 4: 1.53\n",
      "Test accuracy at Epoch 4: 30.417\n",
      "Train Loss at Epoch 5: 1.49\n",
      "Test accuracy at Epoch 5: 34.792\n",
      "Train Loss at Epoch 6: 1.45\n",
      "Test accuracy at Epoch 6: 40.833\n",
      "Train Loss at Epoch 7: 1.42\n",
      "Test accuracy at Epoch 7: 45.625\n",
      "Train Loss at Epoch 8: 1.39\n",
      "Test accuracy at Epoch 8: 47.083\n",
      "Train Loss at Epoch 9: 1.37\n",
      "Test accuracy at Epoch 9: 48.542\n",
      "Train Loss at Epoch 10: 1.34\n",
      "Test accuracy at Epoch 10: 48.750\n",
      "Train Loss at Epoch 11: 1.33\n",
      "Test accuracy at Epoch 11: 48.750\n",
      "Train Loss at Epoch 12: 1.29\n",
      "Test accuracy at Epoch 12: 50.208\n",
      "Train Loss at Epoch 13: 1.27\n",
      "Test accuracy at Epoch 13: 53.958\n",
      "Train Loss at Epoch 14: 1.23\n",
      "Test accuracy at Epoch 14: 57.708\n",
      "Train Loss at Epoch 15: 1.16\n",
      "Test accuracy at Epoch 15: 61.667\n",
      "Train Loss at Epoch 16: 1.10\n",
      "Test accuracy at Epoch 16: 65.208\n",
      "Train Loss at Epoch 17: 1.04\n",
      "Test accuracy at Epoch 17: 64.375\n",
      "Train Loss at Epoch 18: 0.98\n",
      "Test accuracy at Epoch 18: 63.750\n",
      "Train Loss at Epoch 19: 0.93\n",
      "Test accuracy at Epoch 19: 63.125\n",
      "Train Loss at Epoch 20: 0.88\n",
      "Test accuracy at Epoch 20: 63.333\n",
      "Train Loss at Epoch 21: 0.83\n",
      "Test accuracy at Epoch 21: 63.333\n",
      "Train Loss at Epoch 22: 0.80\n",
      "Test accuracy at Epoch 22: 63.542\n",
      "Train Loss at Epoch 23: 0.77\n",
      "Test accuracy at Epoch 23: 65.000\n",
      "Train Loss at Epoch 24: 0.74\n",
      "Test accuracy at Epoch 24: 69.583\n",
      "Train Loss at Epoch 25: 0.70\n",
      "Test accuracy at Epoch 25: 72.500\n",
      "Train Loss at Epoch 26: 0.67\n",
      "Test accuracy at Epoch 26: 75.208\n",
      "Train Loss at Epoch 27: 0.64\n",
      "Test accuracy at Epoch 27: 76.667\n",
      "Train Loss at Epoch 28: 0.61\n",
      "Test accuracy at Epoch 28: 78.125\n",
      "Train Loss at Epoch 29: 0.58\n",
      "Test accuracy at Epoch 29: 80.417\n",
      "Train Loss at Epoch 30: 0.55\n",
      "Test accuracy at Epoch 30: 82.083\n",
      "Train Loss at Epoch 31: 0.53\n",
      "Test accuracy at Epoch 31: 83.125\n",
      "Train Loss at Epoch 32: 0.51\n",
      "Test accuracy at Epoch 32: 83.542\n",
      "Train Loss at Epoch 33: 0.48\n",
      "Test accuracy at Epoch 33: 84.167\n",
      "Train Loss at Epoch 34: 0.47\n",
      "Test accuracy at Epoch 34: 85.000\n",
      "Train Loss at Epoch 35: 0.44\n",
      "Test accuracy at Epoch 35: 85.417\n",
      "Train Loss at Epoch 36: 0.43\n",
      "Test accuracy at Epoch 36: 85.625\n",
      "Train Loss at Epoch 37: 0.42\n",
      "Test accuracy at Epoch 37: 85.833\n",
      "Train Loss at Epoch 38: 0.41\n",
      "Test accuracy at Epoch 38: 86.667\n",
      "Train Loss at Epoch 39: 0.39\n",
      "Test accuracy at Epoch 39: 87.292\n",
      "Train Loss at Epoch 40: 0.38\n",
      "Test accuracy at Epoch 40: 87.292\n",
      "Train Loss at Epoch 41: 0.36\n",
      "Test accuracy at Epoch 41: 87.500\n",
      "Train Loss at Epoch 42: 0.36\n",
      "Test accuracy at Epoch 42: 87.917\n",
      "Train Loss at Epoch 43: 0.34\n",
      "Test accuracy at Epoch 43: 88.542\n",
      "Train Loss at Epoch 44: 0.33\n",
      "Test accuracy at Epoch 44: 88.542\n",
      "Train Loss at Epoch 45: 0.32\n",
      "Test accuracy at Epoch 45: 88.542\n",
      "Train Loss at Epoch 46: 0.32\n",
      "Test accuracy at Epoch 46: 88.333\n",
      "Train Loss at Epoch 47: 0.31\n",
      "Test accuracy at Epoch 47: 88.542\n",
      "Train Loss at Epoch 48: 0.30\n",
      "Test accuracy at Epoch 48: 88.542\n",
      "Train Loss at Epoch 49: 0.29\n",
      "Test accuracy at Epoch 49: 88.750\n"
     ]
    }
   ],
   "source": [
    "# With filter widths [3,5,7] and batch_size 32 the algorithm \n",
    "# achieves around ~90% accuracy on test dataset (50 epochs). \n",
    "# From batch sizes [16,32,64] I found 32 to give best performance\n",
    "\n",
    "session = tf.InteractiveSession()\n",
    "\n",
    "num_steps = 50 # Number of epochs the algorithm runs for\n",
    "\n",
    "# Initialize all variables\n",
    "tf.global_variables_initializer().run()\n",
    "print('Initialized\\n')\n",
    "\n",
    "# Define data batch generators for train and test data\n",
    "# (both keep their position between epochs)\n",
    "train_gen = BatchGenerator(batch_size,train_questions,train_labels)\n",
    "test_gen = BatchGenerator(batch_size,test_questions,test_labels)\n",
    "\n",
    "# How often do we compute the test accuracy\n",
    "test_interval = 1\n",
    "\n",
    "# Compute accuracy for a given set of predictions and labels\n",
    "# labels: one-hot-encoded labels, preds: predicted class indices\n",
    "# Returns the fraction (0 to 1) of the rows where the two agree\n",
    "def accuracy(labels,preds):\n",
    "    return np.sum(np.argmax(labels,axis=1)==preds)/labels.shape[0]\n",
    "\n",
    "# Running the algorithm\n",
    "for step in range(num_steps):\n",
    "    avg_loss = []\n",
    "    \n",
    "    # A single traverse through the whole training set\n",
    "    # NOTE(review): this draws (number of questions // batch_size) - 1 batches, one batch\n",
    "    # fewer than the number of full batches, so the start of an epoch drifts through the data\n",
    "    for tr_i in range((len(train_questions)//batch_size)-1):\n",
    "        # Get a batch of data\n",
    "        tr_inputs, tr_labels = train_gen.generate_batch()\n",
    "        # Optimize the network and compute the loss\n",
    "        l,_ = session.run([loss,optimizer],feed_dict={sent_inputs: tr_inputs, sent_labels: tr_labels})\n",
    "        avg_loss.append(l)\n",
    "\n",
    "    # Print average loss\n",
    "    print('Train Loss at Epoch %d: %.2f'%(step,np.mean(avg_loss)))\n",
    "    test_accuracy = []\n",
    "    \n",
    "    # Compute the test accuracy\n",
    "    if (step+1)%test_interval==0:        \n",
    "        # Only full batches are evaluated, the test questions left over after the\n",
    "        # last full batch are not part of the reported accuracy\n",
    "        for ts_i in range((len(test_questions)-1)//batch_size):\n",
    "            # Get a batch of test data\n",
    "            ts_inputs,ts_labels = test_gen.generate_batch()\n",
    "            # Get predictions for that batch\n",
    "            preds = session.run(predictions,feed_dict={sent_inputs: ts_inputs, sent_labels: ts_labels})\n",
    "            # Compute test accuracy\n",
    "            test_accuracy.append(accuracy(ts_labels,preds))\n",
    "        \n",
    "        # Display the mean test accuracy\n",
    "        print('Test accuracy at Epoch %d: %.3f'%(step,np.mean(test_accuracy)*100.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
